"""
Tensor切片
- torch.chunk(tensor, chunks, dim=0) 按照某个维度平均分块 最后一个可能小于平均值
- torch.split(tensor, split_size_sections, dim=0) 按照某个维度依照第二个参数给出list或者int进行分割tensor

"""
import torch

a = torch.rand(3, 4)
# print(a)
# out = torch.chunk(a, 2, dim=0)
# print(out)
# print(out[0], out[0].shape)
# print(out[1], out[1].shape)
# out = torch.split(a, 3, dim=1)
# print(out)
# print(out)




